Conversation
|
Very nice! I think having this performance boost makes the usage of previous solutions obsolete - at least partially, because using the previous solutions also gives us some sort of stability over the curves (execution-dimension). Can we make it an (default) option and keep the other version? |
we may keep the old code, but in that case i suggest adding a docstring (or even a small section in the docs) to highlight that our current version (sequentially solving for functions) is way slower than using the vectorized version. def __init__(self, f, w0, max_fit_attemps, vectorized=True):in |
Fully agree! Sounds good! |
| self.max_fit_attempts = max_fit_attemps | ||
|
|
||
| def _latents_to_array(self, latents): | ||
| l_x0_mat = jnp.array([latent.x0 for latent in latents]) |
There was a problem hiding this comment.
something like would be a bit more code efficient (while harder to understand though)
return (
jnp.array([attr(latent, "dim") for latent in latents]) for dim in ["x0", "x1",....]
)
Done in this PR
Implemented a faster implementation of curve sampling in
JaxCurveGenerationSolverusingjax.vmap.This is much faster according to my benchmarks.
Tested on 10k curves, meaning 10k latent informations which need to be solved.
With the old code, the computation succeded in 268 seconds, the vectorized code succeeds in 8 seconds.
But this goes with two downsides:
w0.callbacksince the whole matrix is computed at once. This might be a less of an issue, since we still get all results, but much faster and not computed sequentially. Maybe we should provide a way to get the losses for each instance?@windisch What do you think?